# Reksio - Memory Map Editor
# Copyright (C) 2023 CERN
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
# In applying this licence, CERN does not waive the privileges and immunities
# granted to it by virtue of its status as an Intergovernmental Organization or
# submit itself to any jurisdiction.

try:
    import reksio
except ImportError:
    # Non-GUI mode
    import logging
    reksio = logging

import ruamel.yaml

from ruamel.yaml.comments import CommentedMap
from ruamel.yaml.comments import merge_attrib
from copy import deepcopy
from tree import TreeView

from collections import OrderedDict
from collections import defaultdict

def upgrade_all_to_1_0_0(tree: TreeView):
    """
    Upgrade a legacy memory map to schema 1.0.0.

    - Inserts a 'schema-version' mapping (every listed extension set to "1.0.0")
      into the root attributes at position len-1, i.e. before the current last entry.
    - Expands YAML merge keys ("<<: *anchor") in place and removes every mapping
      entry whose value carries a YAML anchor, so the tree no longer relies on
      anchors/aliases.
    """
    def remove_anchors_and_merge(data):
        """Recursively expand merge keys and drop anchored entries from *data*."""
        if isinstance(data, dict):
            anchors = []
            for k, v in data.items():
                if hasattr(v, "yaml_anchor") and v.yaml_anchor():
                    anchors.append(k)
                merges = getattr(v, "merge", [])
                for merge_data in merges:
                    # merge_data is a (position, mapping) pair; copy the mapping so the
                    # result no longer shares objects with the anchored definition.
                    # NOTE(review): update() lets merged values overwrite keys that v
                    # defines itself, contrary to YAML merge-key precedence — confirm.
                    v.update(deepcopy(merge_data[1]))
                if merges:
                    # Drop the merge list once, after all merged mappings were applied.
                    # Deleting it inside the loop raised AttributeError on the second
                    # entry of a multi-merge ("<<: [*a, *b]").
                    delattr(v, merge_attrib)
            for k in anchors:
                del data[k]
            for v in data.values():
                remove_anchors_and_merge(v)
        elif isinstance(data, list):
            for child in data:
                remove_anchors_and_merge(child)

    versions = {
        'core': "1.0.0",
        'x-gena': "1.0.0",
        'x-hdl': "1.0.0",
        'x-fesa': "1.0.0",
        'x-driver-edge': "1.0.0",
        'x-conversions': "1.0.0",
        "x-wbgen": "1.0.0"
    }
    root_tree = tree.attributes
    no_elements = len(root_tree)
    root_tree.insert(no_elements-1, 'schema-version', versions)
    remove_anchors_and_merge(tree.tree)


def upgrade_x_map_info_to_1_0_0(tree: TreeView):
    """
    Upgrade the x-map-info extension to 1.0.0.

    Sets memory-map/schema-version/x-map-info to "1.0.0" and relocates the
    legacy identification attributes:

    memory-map/x-cern-info/map-version                -> memory-map/x-gena/map-version
    memory-map/x-cern-info/ident-code                 -> memory-map/x-map-info/ident
    memory-map/x-cern-info/semantic-mem-map-version   -> memory-map/x-map-info/memmap-version
    memory-map/x-gena/ident-code                      -> memory-map/x-map-info/ident
    memory-map/x-gena/semantic-mem-map-version        -> memory-map/x-map-info/memmap-version
    submap/x-cern-info/semantic-mem-map-version       -> submap/x-map-info/memmap-version

    The whole x-cern-info section (memory-map and submaps) is removed, as are
    x-gena/ident-code and x-gena/semantic-mem-map-version. Values taken from
    x-gena are written after those from x-cern-info, so they win.
    """
    tree.get('schema-version')['x-map-info'] = "1.0.0"

    # memory-map/x-cern-info: (legacy key, destination section, destination key), in order
    cern_moves = (
        ('map-version', 'x-gena', 'map-version'),
        ('ident-code', 'x-map-info', 'ident'),
        ('semantic-mem-map-version', 'x-map-info', 'memmap-version'),
    )
    legacy_info = tree.get('x-cern-info', None)
    if legacy_info is not None:
        for old_key, section, new_key in cern_moves:
            value = legacy_info.get(old_key, None)
            if value is not None:
                tree.setdefault(section, {})[new_key] = value
        del tree['x-cern-info']

    # memory-map/x-gena: (legacy key, x-map-info key), in order
    gena_moves = (
        ('ident-code', 'ident'),
        ('semantic-mem-map-version', 'memmap-version'),
    )
    gena_section = tree.get('x-gena', None)
    if gena_section is not None:
        for old_key, new_key in gena_moves:
            value = gena_section.get(old_key, None)
            if value is None:
                continue
            tree.setdefault('x-map-info', {})[new_key] = value
            del gena_section[old_key]

    # Materialize the submap list before editing the nodes.
    submap_nodes = [n for n in tree.walk_pre_order() if n.node_type == "submap"]
    for sub in submap_nodes:
        sub_legacy = sub.get('x-cern-info', None)
        if sub_legacy is None:
            continue
        version = sub_legacy.get('semantic-mem-map-version', None)
        if version is not None:
            sub.setdefault('x-map-info', {})['memmap-version'] = version
        del sub['x-cern-info']


def upgrade_core_to_2_0_0(tree: TreeView):
    """
    Upgrade the core schema to 2.0.0 by folding "note" into "description".

    Concerns memory-map, code-field, reg, array, memory, submap and block nodes.
    The "note" attribute is always removed. A non-empty stripped note is:
    - used as the description when there is no description, or when the
      description is empty/whitespace only;
    - appended (separated by one space) to the stripped description when the
      two texts differ.

    Finally sets memory-map/schema-version/core to "2.0.0".
    """
    affected_node_types = ['memory-map', 'code-field', 'reg', 'array', 'memory', 'submap', 'block']

    nodes = [node for node in tree.walk_pre_order() if node.node_type in affected_node_types]

    for node in nodes:
        note = node.get('note', None)
        if note is not None:
            stripped_note = str(note).strip()
            description = node.get('description', None)
            stripped_description = "" if description is None else str(description).strip()
            if stripped_note and stripped_note != stripped_description:
                if stripped_description:
                    node['description'] = stripped_description + " " + stripped_note
                else:
                    # No usable description: avoid a leading separator space.
                    node['description'] = stripped_note
        if 'note' in node:
            del node['note']

    # set new version
    schema_version = tree.get('schema-version')
    schema_version['core'] = "2.0.0"


def upgrade_x_driver_edge_to_2_0_0(tree: TreeView):
    """
    Version bump of x-driver-edge for the Address-Space upgrade.

    Only sets memory-map/schema-version/x-driver-edge to "2.0.0". The data
    conversion of x-driver-edge is done by upgrade_core_to_3_0_0, because the
    Address-Space and x-driver-edge/pci-bars updates cannot be done separately.
    """
    tree.get('schema-version')['x-driver-edge'] = "2.0.0"


def upgrade_x_fesa_to_2_0_0(tree: TreeView):
    """
    Upgrade the x-fesa extension to 2.0.0.

    For every reg, array and memory node, the legacy string in
    x-fesa/persistence is split into two booleans:

        "PPM"  -> persistence: True,  multiplexed: True
        "Fesa" -> persistence: True,  multiplexed: False
        "None" -> persistence: False, multiplexed: False

    Any other persistence value is left untouched. Finally sets
    memory-map/schema-version/x-fesa to "2.0.0".
    """
    # (legacy value, new persistence, new multiplexed)
    conversions = (
        ("PPM", True, True),
        ("Fesa", True, False),
        ("None", False, False),
    )
    candidates = [n for n in tree.walk_pre_order() if n.node_type in ('reg', 'array', 'memory')]
    for candidate in candidates:
        fesa = candidate.get('x-fesa', None)
        if fesa is None:
            continue
        legacy = fesa.get('persistence', None)
        if legacy is None:
            continue
        for label, persistent, multiplexed in conversions:
            if legacy == label:
                fesa['persistence'] = persistent
                fesa['multiplexed'] = multiplexed
                break

    tree.get('schema-version')['x-fesa'] = "2.0.0"


def upgrade_x_gena_to_2_0_0(tree: TreeView):
    """
    Upgrade the x-gena extension to 2.0.0 (code-fields become x-enums).

    Sets memory-map/schema-version/x-gena to "2.0.0" and
    memory-map/schema-version/x-enums to "1.0.0". For every field and reg whose
    x-gena section has 'code-fields':
    - an enum (name, width, one item per code-field) is appended to the
      top-level 'x-enums' list, which is inserted at position 0 of the root
      attributes when missing;
    - 'code-fields' is removed from x-gena, and x-gena itself is removed when
      it becomes empty;
    - the node gets x-enums/name pointing to the new enum.
    """
    def get_full_name(node):
        """
        Return the node name prefixed by its ancestors' names, joined with '_'.

        The outermost ancestor (the one without a parent) is not included; a
        node without a parent returns just its own name.
        """
        name = []
        if node.parent is None:
            return node["name"]
        name.append(node['name'])
        parent = node.parent
        while getattr(parent, 'parent', None):
            name.insert(0, parent['name'])
            parent = parent.parent
        return "_".join(name)

    def get_width(node):
        """
        Return the enum width in bits for a reg or field node (None otherwise).

        reg:   int(width), halved when x-gena/rmw is true.
        field: hi - lo + 1 for a "hi-lo" range, 1 for a single-bit range.
        """
        if node.node_type == "reg":
            width = int(node['width'])
            x_gena = node.get('x-gena', None)
            if x_gena is None:
                return width
            # rmw may be loaded as a YAML boolean or as the string 'true';
            # comparing only against the string missed the boolean form.
            rmw = x_gena.get('rmw', False)
            if str(rmw).strip().lower() == 'true':
                return width // 2
            return width
        elif node.node_type == "field":
            bit_range = str(node['range'])
            if '-' in bit_range:
                hi, lo = map(int, bit_range.split('-'))
                return hi - lo + 1
            return 1

    # set new version
    schema_version = tree.get('schema-version')
    schema_version['x-gena'] = "2.0.0"
    schema_version['x-enums'] = "1.0.0"

    affected_node_types = ['field', 'reg']
    nodes = [node for node in tree.walk_pre_order() if node.node_type in affected_node_types]

    for node in nodes:
        x_gena = node.get('x-gena', None)
        if x_gena is None:
            continue
        code_fields = x_gena.get('code-fields', None)
        if code_fields is None:
            continue

        root_tree = tree.attributes
        if 'x-enums' not in root_tree:
            # create top level x-enums
            root_tree.insert(0, 'x-enums', [])
        root_x_enums = tree['x-enums']

        # create enum type
        enum_name = get_full_name(node)
        enum = {
            'name': enum_name,
            'width': get_width(node),
            'children': []
        }

        for cf in code_fields:
            cf = cf['code-field']
            enum_item = {
                'name': cf['name'],
                'value': cf['code']
            }
            # copy optional texts only when they contain something besides whitespace
            for optional_key in ('comment', 'description'):
                text = cf[optional_key] if optional_key in cf else None
                if text is not None and str(text).strip():
                    enum_item[optional_key] = text
            enum['children'].append({'item': enum_item})
        root_x_enums.append({'enum': enum})

        del x_gena['code-fields']
        if not len(x_gena):
            del node['x-gena']

        node_x_enums = node.setdefault('x-enums', {})
        node_x_enums['name'] = enum_name

def _is_map(node) -> bool:
    return node and node.node_type == 'memory-map'

def _is_submap(node) -> bool:
    return node and node.node_type == 'submap'

def upgrade_core_to_3_0_0(tree: TreeView):
    '''
    Upgrade the core schema to 3.0.0 by turning PCI BAR assignments into Address-Spaces.

    - sets memory-map/schema-version/core to 3.0.0
    - groups the direct submap children of the memory map by their
      x-driver-edge/pci-bar-name (falling back to
      memory-map/x-driver-edge/default-pci-bar-name) and wraps each group in a new
      address-space named after the BAR
    - links the matching memory-map/x-driver-edge/pci-bars entries to the new
      address spaces (base-addr 0x0 for the global root map), creating entries when
      no pci-bars list exists
    - removes the deprecated submap/x-driver-edge/pci-bar-name and
      memory-map/x-driver-edge/default-pci-bar-name attributes

    Nothing is converted when the memory map has no x-driver-edge/pci-bars (the
    deprecated attributes are still removed).

    Address-Space and x-driver-edge/pci-bars updates cannot be separated and must be
    performed in the same function.

    Raises Exception (after logging it as critical) when pci-bars are defined but a
    submap has neither a pci-bar-name nor a default-pci-bar-name to fall back on.
    '''
    def _is_upgradable(node) -> bool:
        '''
        For the time being, only nodes being submaps are upgradable to address-spaces
        '''
        return bool(node and node.node_type in ["submap"])

    def _get_pci_bars(submaps, pci_bar_default=None) -> defaultdict:
        '''
        Create a dictionary that groups submaps by each found BAR name,
        in order to create new Address-Spaces.

        Side effect: removes the deprecated x-driver-edge/pci-bar-name from each submap.

        E.g. the output dictionary:
        {
            'bar0': [
                TreeView(submap, parent=TreeView(memory-map, parent=None)): {'name': 'hwInfo', ...},
                TreeView(submap, parent=TreeView(memory-map, parent=None)): {'name': 'sis8300ku', ...},
                TreeView(submap, parent=TreeView(memory-map, parent=None)): {'name': 'app', ...},
                ...
            ],
            'bar4': [
                TreeView(submap, parent=TreeView(memory-map, parent=None)): {'name': 'fgc_ddr', ...},
                TreeView(submap, parent=TreeView(memory-map, parent=None)): {'name': 'acq_ddr', ...},
                TreeView(submap, parent=TreeView(memory-map, parent=None)): {'name': 'acq_ram', ...},
                ...,
            ],
            'barX': [...],
        }
        '''
        pci_bars = defaultdict(list)
        for submap in submaps:
            xedge = submap.attributes.get('x-driver-edge')
            pci_bar = xedge and xedge.get('pci-bar-name', None)

            parent_xedge = submap.parent.attributes.get('x-driver-edge')
            parent_pci_bars = parent_xedge and parent_xedge.get('pci-bars')

            # if pci-bars are not defined in the submap's parent, do not convert
            # any submap, but drop the deprecated attribute anyway
            if not parent_pci_bars:
                if pci_bar:
                    del submap.attributes['x-driver-edge']['pci-bar-name']
                continue

            # conversion is not possible when the parent has pci-bars defined and its children don't
            if not pci_bar and not pci_bar_default:
                parent_name = submap.parent.attributes.get('name')
                sb_name = submap.attributes.get('name')
                err_msg = f'Submap "{sb_name}": "pci-bar-name" not defined and no "default-pci-bar-name" in its parent "{parent_name}" - conversion not possible!'
                reksio.critical(err_msg)
                raise Exception(err_msg)

            # remove deprecated attribute anyway
            if pci_bar:
                del submap.attributes['x-driver-edge']['pci-bar-name']

            pci_bars[pci_bar or pci_bar_default].append(submap)

        return pci_bars

    def _root_pci_bar_list(mem_map):
        '''
        Return mem_map's x-driver-edge/pci-bars list, creating missing levels.
        '''
        attrs = mem_map.attributes
        if attrs.get('x-driver-edge') is None:
            attrs['x-driver-edge'] = CommentedMap()
        xedge = attrs['x-driver-edge']
        if xedge.get('pci-bars') is None:
            xedge['pci-bars'] = []
        return xedge['pci-bars']

    def _create_address_spaces(mem_map, pci_bars):
        '''
        Replace mem_map's children with one address-space per BAR name of pci_bars
        (see _get_pci_bars()), each wrapping its submaps, and link the memory-map
        x-driver-edge/pci-bars entries to those address spaces.
        '''
        # NOTE(review): this discards every existing child of mem_map, including
        # non-submap children (reg, block, ...) — confirm this is intended.
        mem_map.attributes['children'] = []

        root_xedge = mem_map.attributes.get('x-driver-edge')
        root_pci_bars = root_xedge and root_xedge.get('pci-bars')

        for bar_name, bar_submaps in pci_bars.items():
            addr_space_children = CommentedMap({
                    'name': bar_name,
                    'children': [CommentedMap({'submap': sb.attributes}) for sb in bar_submaps]
            })

            address_space = CommentedMap({'address-space': addr_space_children})
            mem_map.attributes['children'].append(address_space)

            # No pci-bars defined yet: create a new pci-bar entry for this BAR
            if not root_pci_bars:
                pci_bar_attrs = CommentedMap({
                        'name': bar_name,
                        'address-space': bar_name
                })

                # set to 0x0 for the global top map only (compatibility with EDGE driver generator)
                if mem_map.is_global_root():
                    pci_bar_attrs.update({'base-addr': '0x0'})

                # pci-bars is a list of {'pci-bar': {...}} entries; calling update()
                # on it (as before) raised AttributeError.
                _root_pci_bar_list(mem_map).append(CommentedMap({'pci-bar': pci_bar_attrs}))
                continue

            # Update top x-driver-edge/pci-bars if they exist
            for pci_bar in mem_map.attributes['x-driver-edge']['pci-bars']:
                attrs = pci_bar['pci-bar']
                if attrs['name'] != bar_name:
                    continue

                attrs.update({'address-space': bar_name})
                if mem_map.is_global_root():
                    attrs.update({'base-addr': '0x0'})

    # Get memory map and its all nodes to be upgraded
    # The memory map is root in this local scope. In order to get the global root map, use the 'is_global_root()' method
    mem_map, *_ = [node for node in tree.walk_pre_order() if _is_map(node)]
    upgradable_nodes = [node for node in mem_map.children if _is_upgradable(node)]

    mem_map_xedge = mem_map.attributes.get('x-driver-edge')
    pci_bar_default = mem_map_xedge and mem_map_xedge.get('default-pci-bar-name')
    pci_bars = _get_pci_bars(upgradable_nodes, pci_bar_default)

    # update schema version
    schema_version = mem_map.get('schema-version')
    schema_version['core'] = '3.0.0'

    # remove default PCI bar if exists, and no PCI bars are defined
    if not pci_bars and pci_bar_default:
        del mem_map.attributes['x-driver-edge']['default-pci-bar-name']

    # Do update only when PCI bars are defined
    if not pci_bars:
        mem_map_name = mem_map.attributes.get('name', 'MEMMAP_NAME_NOT_DEFINED')
        reksio.debug(f'upgrade_core_to_3_0_0/{mem_map_name}: PCI BARs not defined, skip upgrade and continue...')
        return

    _create_address_spaces(mem_map, pci_bars)

    # remove default PCI-Bar, no longer needed
    if pci_bar_default:
        del mem_map.attributes['x-driver-edge']['default-pci-bar-name']


def upgrade_x_driver_edge_to_3_0_0(tree: TreeView):
    """
    Upgrade the x-driver-edge extension to 3.0.0.

    For every memory map in the tree:
    - renames 'x-driver-edge/schema-version' to 'x-driver-edge/edge-version'
      (values other than 2.1 / 3.1 are replaced by 3.1, with a warning);
    - removes 'base-addr' from every x-driver-edge/pci-bars entry, since the base
      address of Address-Space nodes is hardcoded to 0x0. Do not confuse these
      pci-bars with the container attribute of the same name.

    Sets memory-map/schema-version/x-driver-edge to 3.0.0.
    """
    def _upgrade_pci_bars(mem_map, mem_map_name):
        '''
        Remove the base address child from the PCI-Bars of mem_map, which are stored as:
        {
            'pci-bar':
            {
                'name': 'bar0',
                'number': 0,
                [...],
                'address-space': 'bar0'
            }
        },
        [...]
        {
            'pci-bar':
            {
                'name': 'bar4',
                'number': 4,
                [...],
                'address-space': 'bar4',
                'base-addr': '0x0'
            }
        }
        mem_map_name is only used in log messages.
        '''
        xedge = mem_map.get('x-driver-edge')
        pci_bars = xedge and xedge.get('pci-bars')
        if not pci_bars:
            reksio.debug(f'upgrade_x_driver_edge_to_3_0_0/{mem_map_name}: PCI BARs not defined, skip upgrade and continue...')
            return

        for pcibar in pci_bars:
            _pcibar = pcibar.get('pci-bar')  # extra access layer - look comment up
            if _pcibar is not None:
                _pcibar.pop('base-addr', None)  # remove base address without raising KeyError

    schema_version = tree.get('schema-version')
    schema_version['x-driver-edge'] = "3.0.0"

    all_memory_maps = [node for node in tree.walk_pre_order() if _is_map(node)]
    for mem_map in all_memory_maps:
        # f-string so a non-str name cannot break the log prefix
        mem_map_name = f"{mem_map.get('name', 'MEMMAP_NAME_NOT_DEFINED')}.cheby"

        # rename 'x-driver-edge/schema-version' into 'x-driver-edge/edge-version'
        xedge = mem_map.attributes.get('x-driver-edge')
        if xedge and 'schema-version' in xedge:
            reksio.debug(f'upgrade_x_driver_edge_to_3_0_0/{mem_map_name}/x-driver-edge: "schema-version" renamed to "edge-version"')
            edge_ver = xedge.pop('schema-version')
            if str(edge_ver) not in ['2.1', '3.1']:  # TODO: update to PyCheby enum values
                reksio.warn(f'upgrade_x_driver_edge_to_3_0_0/{mem_map_name}: "x-driver-edge/schema-version"(={edge_ver}) not valid, set to default 3.1, continue...')
                # Type comes from PyCheby parser. width/prec mirror what the ruamel
                # parser records for "3.1"; without them the round-trip float
                # representer compares None with ints and fails when dumping.
                edge_ver = ruamel.yaml.scalarfloat.ScalarFloat(3.1, width=3, prec=1, m_sign=False, m_lead0=0)

            xedge['edge-version'] = edge_ver

        _upgrade_pci_bars(mem_map, mem_map_name)

